{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "matchzoo version 1.0\n",
      "`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]\n",
      "data loading ...\n",
      "data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`\n"
     ]
    }
   ],
   "source": [
    "# Shared setup notebook: per its printed output it reports the matchzoo version, initialises\n",
    "# `ranking_task` with its metrics, and loads `train_pack_raw`, `dev_pack_raw`, `test_pack_raw`.\n",
    "# NOTE(review): `mz` is used below without an import here - presumably init.ipynb provides it.\n",
    "%run init.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ask the Bert model class for its default preprocessor; it is applied to every data split below.\n",
    "preprocessor = mz.models.Bert.get_default_preprocessor()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with encode: 100%|██████████| 2118/2118 [00:00<00:00, 4920.66it/s]\n",
      "Processing text_right with encode: 100%|██████████| 18841/18841 [00:11<00:00, 1574.02it/s]\n",
      "Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 688806.38it/s]\n",
      "Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 969547.18it/s]\n",
      "Processing text_left with encode: 100%|██████████| 122/122 [00:00<00:00, 7249.90it/s]\n",
      "Processing text_right with encode: 100%|██████████| 1115/1115 [00:00<00:00, 1583.57it/s]\n",
      "Processing length_left with len: 100%|██████████| 122/122 [00:00<00:00, 222577.25it/s]\n",
      "Processing length_right with len: 100%|██████████| 1115/1115 [00:00<00:00, 698946.19it/s]\n",
      "Processing text_left with encode: 100%|██████████| 237/237 [00:00<00:00, 6706.45it/s]\n",
      "Processing text_right with encode: 100%|██████████| 2300/2300 [00:01<00:00, 1709.31it/s]\n",
      "Processing length_left with len: 100%|██████████| 237/237 [00:00<00:00, 315171.23it/s]\n",
      "Processing length_right with len: 100%|██████████| 2300/2300 [00:00<00:00, 569743.63it/s]\n"
     ]
    }
   ],
   "source": [
    "# Run the same preprocessor over the raw train, dev and test packs, in that order.\n",
    "train_pack_processed, dev_pack_processed, test_pack_processed = [\n",
    "    preprocessor.transform(raw_pack)\n",
    "    for raw_pack in (train_pack_raw, dev_pack_raw, test_pack_raw)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training set in 'pair' mode with num_dup=2 and num_neg=1\n",
    "# (presumably duplicates per positive and sampled negatives per positive - see the Dataset docs).\n",
    "trainset = mz.dataloader.Dataset(data_pack=train_pack_processed, mode='pair', num_dup=2, num_neg=1)\n",
    "\n",
    "# The test split relies on the Dataset defaults.\n",
    "testset = mz.dataloader.Dataset(data_pack=test_pack_processed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A single padding callback, taken from the Bert model class, is shared by both loaders.\n",
    "padding_callback = mz.models.Bert.get_default_padding_callback()\n",
    "\n",
    "# Training batches: stage 'train', resampling on, sorting off.\n",
    "trainloader = mz.dataloader.DataLoader(\n",
    "    dataset=trainset, batch_size=20, stage='train',\n",
    "    resample=True, sort=False, callback=padding_callback,\n",
    ")\n",
    "\n",
    "# Evaluation batches over the test set, using stage 'dev'.\n",
    "testloader = mz.dataloader.DataLoader(\n",
    "    dataset=testset, batch_size=20, stage='dev', callback=padding_callback,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Bert(\n",
      "  (bert): BertModule(\n",
      "    (bert): BertModel(\n",
      "      (embeddings): BertEmbeddings(\n",
      "        (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
      "        (position_embeddings): Embedding(512, 768)\n",
      "        (token_type_embeddings): Embedding(2, 768)\n",
      "        (LayerNorm): BertLayerNorm()\n",
      "        (dropout): Dropout(p=0.1, inplace=False)\n",
      "      )\n",
      "      (encoder): BertEncoder(\n",
      "        (layer): ModuleList(\n",
      "          (0): BertLayer(\n",
      "            (attention): BertAttention(\n",
      "              (self): BertSelfAttention(\n",
      "                (query): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (key): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (value): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (dropout): Dropout(p=0.1, inplace=False)\n",
      "              )\n",
      "              (output): BertSelfOutput(\n",
      "                (dense): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (LayerNorm): BertLayerNorm()\n",
      "                (dropout): Dropout(p=0.1, inplace=False)\n",
      "              )\n",
      "            )\n",
      "            (intermediate): BertIntermediate(\n",
      "              (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
      "            )\n",
      "            (output): BertOutput(\n",
      "              (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
      "              (LayerNorm): BertLayerNorm()\n",
      "              (dropout): Dropout(p=0.1, inplace=False)\n",
      "            )\n",
      "          )\n",
      "          (1): BertLayer(\n",
      "            (attention): BertAttention(\n",
      "              (self): BertSelfAttention(\n",
      "                (query): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (key): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (value): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (dropout): Dropout(p=0.1, inplace=False)\n",
      "              )\n",
      "              (output): BertSelfOutput(\n",
      "                (dense): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (LayerNorm): BertLayerNorm()\n",
      "                (dropout): Dropout(p=0.1, inplace=False)\n",
      "              )\n",
      "            )\n",
      "            (intermediate): BertIntermediate(\n",
      "              (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
      "            )\n",
      "            (output): BertOutput(\n",
      "              (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
      "              (LayerNorm): BertLayerNorm()\n",
      "              (dropout): Dropout(p=0.1, inplace=False)\n",
      "            )\n",
      "          )\n",
      "          (2): BertLayer(\n",
      "            (attention): BertAttention(\n",
      "              (self): BertSelfAttention(\n",
      "                (query): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (key): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (value): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (dropout): Dropout(p=0.1, inplace=False)\n",
      "              )\n",
      "              (output): BertSelfOutput(\n",
      "                (dense): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (LayerNorm): BertLayerNorm()\n",
      "                (dropout): Dropout(p=0.1, inplace=False)\n",
      "              )\n",
      "            )\n",
      "            (intermediate): BertIntermediate(\n",
      "              (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
      "            )\n",
      "            (output): BertOutput(\n",
      "              (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
      "              (LayerNorm): BertLayerNorm()\n",
      "              (dropout): Dropout(p=0.1, inplace=False)\n",
      "            )\n",
      "          )\n",
      "          (3): BertLayer(\n",
      "            (attention): BertAttention(\n",
      "              (self): BertSelfAttention(\n",
      "                (query): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (key): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (value): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (dropout): Dropout(p=0.1, inplace=False)\n",
      "              )\n",
      "              (output): BertSelfOutput(\n",
      "                (dense): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (LayerNorm): BertLayerNorm()\n",
      "                (dropout): Dropout(p=0.1, inplace=False)\n",
      "              )\n",
      "            )\n",
      "            (intermediate): BertIntermediate(\n",
      "              (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
      "            )\n",
      "            (output): BertOutput(\n",
      "              (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
      "              (LayerNorm): BertLayerNorm()\n",
      "              (dropout): Dropout(p=0.1, inplace=False)\n",
      "            )\n",
      "          )\n",
      "          (4): BertLayer(\n",
      "            (attention): BertAttention(\n",
      "              (self): BertSelfAttention(\n",
      "                (query): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (key): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (value): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (dropout): Dropout(p=0.1, inplace=False)\n",
      "              )\n",
      "              (output): BertSelfOutput(\n",
      "                (dense): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (LayerNorm): BertLayerNorm()\n",
      "                (dropout): Dropout(p=0.1, inplace=False)\n",
      "              )\n",
      "            )\n",
      "            (intermediate): BertIntermediate(\n",
      "              (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
      "            )\n",
      "            (output): BertOutput(\n",
      "              (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
      "              (LayerNorm): BertLayerNorm()\n",
      "              (dropout): Dropout(p=0.1, inplace=False)\n",
      "            )\n",
      "          )\n",
      "          (5): BertLayer(\n",
      "            (attention): BertAttention(\n",
      "              (self): BertSelfAttention(\n",
      "                (query): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (key): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (value): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (dropout): Dropout(p=0.1, inplace=False)\n",
      "              )\n",
      "              (output): BertSelfOutput(\n",
      "                (dense): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (LayerNorm): BertLayerNorm()\n",
      "                (dropout): Dropout(p=0.1, inplace=False)\n",
      "              )\n",
      "            )\n",
      "            (intermediate): BertIntermediate(\n",
      "              (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
      "            )\n",
      "            (output): BertOutput(\n",
      "              (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
      "              (LayerNorm): BertLayerNorm()\n",
      "              (dropout): Dropout(p=0.1, inplace=False)\n",
      "            )\n",
      "          )\n",
      "          (6): BertLayer(\n",
      "            (attention): BertAttention(\n",
      "              (self): BertSelfAttention(\n",
      "                (query): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (key): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (value): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (dropout): Dropout(p=0.1, inplace=False)\n",
      "              )\n",
      "              (output): BertSelfOutput(\n",
      "                (dense): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (LayerNorm): BertLayerNorm()\n",
      "                (dropout): Dropout(p=0.1, inplace=False)\n",
      "              )\n",
      "            )\n",
      "            (intermediate): BertIntermediate(\n",
      "              (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
      "            )\n",
      "            (output): BertOutput(\n",
      "              (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
      "              (LayerNorm): BertLayerNorm()\n",
      "              (dropout): Dropout(p=0.1, inplace=False)\n",
      "            )\n",
      "          )\n",
      "          (7): BertLayer(\n",
      "            (attention): BertAttention(\n",
      "              (self): BertSelfAttention(\n",
      "                (query): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (key): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (value): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (dropout): Dropout(p=0.1, inplace=False)\n",
      "              )\n",
      "              (output): BertSelfOutput(\n",
      "                (dense): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (LayerNorm): BertLayerNorm()\n",
      "                (dropout): Dropout(p=0.1, inplace=False)\n",
      "              )\n",
      "            )\n",
      "            (intermediate): BertIntermediate(\n",
      "              (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
      "            )\n",
      "            (output): BertOutput(\n",
      "              (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
      "              (LayerNorm): BertLayerNorm()\n",
      "              (dropout): Dropout(p=0.1, inplace=False)\n",
      "            )\n",
      "          )\n",
      "          (8): BertLayer(\n",
      "            (attention): BertAttention(\n",
      "              (self): BertSelfAttention(\n",
      "                (query): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (key): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (value): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (dropout): Dropout(p=0.1, inplace=False)\n",
      "              )\n",
      "              (output): BertSelfOutput(\n",
      "                (dense): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (LayerNorm): BertLayerNorm()\n",
      "                (dropout): Dropout(p=0.1, inplace=False)\n",
      "              )\n",
      "            )\n",
      "            (intermediate): BertIntermediate(\n",
      "              (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
      "            )\n",
      "            (output): BertOutput(\n",
      "              (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
      "              (LayerNorm): BertLayerNorm()\n",
      "              (dropout): Dropout(p=0.1, inplace=False)\n",
      "            )\n",
      "          )\n",
      "          (9): BertLayer(\n",
      "            (attention): BertAttention(\n",
      "              (self): BertSelfAttention(\n",
      "                (query): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (key): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (value): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (dropout): Dropout(p=0.1, inplace=False)\n",
      "              )\n",
      "              (output): BertSelfOutput(\n",
      "                (dense): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (LayerNorm): BertLayerNorm()\n",
      "                (dropout): Dropout(p=0.1, inplace=False)\n",
      "              )\n",
      "            )\n",
      "            (intermediate): BertIntermediate(\n",
      "              (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
      "            )\n",
      "            (output): BertOutput(\n",
      "              (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
      "              (LayerNorm): BertLayerNorm()\n",
      "              (dropout): Dropout(p=0.1, inplace=False)\n",
      "            )\n",
      "          )\n",
      "          (10): BertLayer(\n",
      "            (attention): BertAttention(\n",
      "              (self): BertSelfAttention(\n",
      "                (query): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (key): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (value): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (dropout): Dropout(p=0.1, inplace=False)\n",
      "              )\n",
      "              (output): BertSelfOutput(\n",
      "                (dense): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (LayerNorm): BertLayerNorm()\n",
      "                (dropout): Dropout(p=0.1, inplace=False)\n",
      "              )\n",
      "            )\n",
      "            (intermediate): BertIntermediate(\n",
      "              (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
      "            )\n",
      "            (output): BertOutput(\n",
      "              (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
      "              (LayerNorm): BertLayerNorm()\n",
      "              (dropout): Dropout(p=0.1, inplace=False)\n",
      "            )\n",
      "          )\n",
      "          (11): BertLayer(\n",
      "            (attention): BertAttention(\n",
      "              (self): BertSelfAttention(\n",
      "                (query): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (key): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (value): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (dropout): Dropout(p=0.1, inplace=False)\n",
      "              )\n",
      "              (output): BertSelfOutput(\n",
      "                (dense): Linear(in_features=768, out_features=768, bias=True)\n",
      "                (LayerNorm): BertLayerNorm()\n",
      "                (dropout): Dropout(p=0.1, inplace=False)\n",
      "              )\n",
      "            )\n",
      "            (intermediate): BertIntermediate(\n",
      "              (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
      "            )\n",
      "            (output): BertOutput(\n",
      "              (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
      "              (LayerNorm): BertLayerNorm()\n",
      "              (dropout): Dropout(p=0.1, inplace=False)\n",
      "            )\n",
      "          )\n",
      "        )\n",
      "      )\n",
      "      (pooler): BertPooler(\n",
      "        (dense): Linear(in_features=768, out_features=768, bias=True)\n",
      "        (activation): Tanh()\n",
      "      )\n",
      "    )\n",
      "  )\n",
      "  (dropout): Dropout(p=0.2, inplace=False)\n",
      "  (out): Linear(in_features=768, out_features=1, bias=True)\n",
      ")\n",
      "Trainable params:  109483009\n"
     ]
    }
   ],
   "source": [
    "# Configure the MatchZoo Bert model for the ranking task, then build it.\n",
    "model = mz.models.Bert()\n",
    "model.params['task'] = ranking_task\n",
    "model.params['mode'] = 'bert-base-uncased'\n",
    "model.params['dropout_rate'] = 0.2\n",
    "model.build()\n",
    "\n",
    "# Show the architecture, then the number of parameters that require gradients.\n",
    "print(model)\n",
    "print(\n",
    "    'Trainable params: ',\n",
    "    sum(weight.numel() for weight in filter(lambda w: w.requires_grad, model.parameters()))\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_transformers import AdamW, WarmupLinearSchedule\n",
    "\n",
    "# Schedule length lives in one place so the scheduler and the Trainer cannot drift apart.\n",
    "NUM_EPOCHS = 10\n",
    "WARMUP_STEPS = 6\n",
    "\n",
    "# Two optimizer groups: parameters whose name contains 'bias' or 'LayerNorm.weight' get no\n",
    "# weight decay, every other parameter is decayed with 5e-5.\n",
    "no_decay = ['bias', 'LayerNorm.weight']\n",
    "optimizer_grouped_parameters = [\n",
    "    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 5e-5},\n",
    "    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}\n",
    "]\n",
    "\n",
    "optimizer = AdamW(optimizer_grouped_parameters, lr=5e-5, betas=(0.9, 0.98), eps=1e-8)\n",
    "\n",
    "# After warm-up WarmupLinearSchedule multiplies the LR by\n",
    "# max(0, (t_total - step) / max(1, t_total - warmup_steps)), so t_total has to be the real\n",
    "# number of scheduler steps. The former t_total=-1 turns that factor into 0 for every\n",
    "# step >= warmup_steps; the earlier recorded run shows it: validation metrics were identical\n",
    "# from epoch 6 to epoch 10 because the last four epochs trained with lr=0.\n",
    "# NOTE(review): that run implies the Trainer steps the scheduler once per epoch. If your\n",
    "# MatchZoo version steps once per batch, use t_total=NUM_EPOCHS * len(trainloader) instead.\n",
    "# NOTE(review): the warm-up factor at step 0 is 0 / WARMUP_STEPS, so the first epoch\n",
    "# presumably still runs at lr=0 - consider a shorter warm-up. TODO confirm.\n",
    "scheduler = WarmupLinearSchedule(optimizer, warmup_steps=WARMUP_STEPS, t_total=NUM_EPOCHS)\n",
    "\n",
    "# NOTE(review): validation uses `testloader`; `dev_pack_processed` is never consumed in this\n",
    "# notebook, so any epoch/model selection based on these metrics is tuned on the test split.\n",
    "trainer = mz.trainers.Trainer(\n",
    "    model=model,\n",
    "    optimizer=optimizer,\n",
    "    scheduler=scheduler,\n",
    "    trainloader=trainloader,\n",
    "    validloader=testloader,\n",
    "    validate_interval=None,\n",
    "    epochs=NUM_EPOCHS\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "80339848290e43e6b16fcf23910b00ce",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=102), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-102 Loss-0.976]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5273 - normalized_discounted_cumulative_gain@5(0.0): 0.6022 - mean_average_precision(0.0): 0.5671\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aaaa89b4da994af1b153999661187c0d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=102), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-204 Loss-0.657]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.664 - normalized_discounted_cumulative_gain@5(0.0): 0.7286 - mean_average_precision(0.0): 0.6804\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d8a34039178046acb5e3e3af4b907289",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=102), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-306 Loss-0.487]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.686 - normalized_discounted_cumulative_gain@5(0.0): 0.7388 - mean_average_precision(0.0): 0.6925\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0ac53bda88e14fe5b4db916f7e7bebeb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=102), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-408 Loss-0.374]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.7179 - normalized_discounted_cumulative_gain@5(0.0): 0.7489 - mean_average_precision(0.0): 0.705\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1670080947ae4ded8c048b24c166b209",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=102), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-510 Loss-0.293]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.7224 - normalized_discounted_cumulative_gain@5(0.0): 0.7576 - mean_average_precision(0.0): 0.7132\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8003d038581d4db885f0f84f6a9a1d1f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=102), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-612 Loss-0.182]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.7139 - normalized_discounted_cumulative_gain@5(0.0): 0.7453 - mean_average_precision(0.0): 0.7059\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4aa622124aa6427181032c67f9172318",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=102), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-714 Loss-0.098]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.7139 - normalized_discounted_cumulative_gain@5(0.0): 0.7453 - mean_average_precision(0.0): 0.7059\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7b6f89f6bf734aa991f59ce2d81aadbb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=102), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-816 Loss-0.094]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.7139 - normalized_discounted_cumulative_gain@5(0.0): 0.7453 - mean_average_precision(0.0): 0.7059\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5b8e6d243a8b47c49c99d47062a0d33c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=102), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-918 Loss-0.103]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.7139 - normalized_discounted_cumulative_gain@5(0.0): 0.7453 - mean_average_precision(0.0): 0.7059\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "54f2db26cc8045eda6bcc228f401cb42",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=102), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-1020 Loss-0.103]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.7139 - normalized_discounted_cumulative_gain@5(0.0): 0.7453 - mean_average_precision(0.0): 0.7059\n",
      "\n",
      "Cost time: 10848.315711736679s\n"
     ]
    }
   ],
   "source": [
    "# Start the training loop configured on `trainer`; the recorded run validated after each\n",
    "# 102-iteration epoch and reported a total cost of about 10848 s (roughly 3 hours).\n",
    "trainer.run()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
